{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hv7eUTNypW-X"
      },
      "source": [
        "<span style=\"color: red;\">Requirement when running in Goolge Colab</span>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8TBCfleupW-Z"
      },
      "outputs": [],
      "source": [
        "!pip install diffusers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V_x1-kHhpW-a"
      },
      "source": [
        "#  Chapter 7 - Stable Face\n",
        "\n",
        "Now that we have the ability to reconstruct real-world images in Stable Diffusion and also learnt how to replace attention layers, we reach the season finale, to combine these methods to modify any facial features arbitarily."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AvUa55nHpW-b"
      },
      "source": [
        "## Set up the Stable Diffusion pipeline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KjJTEyyspW-b"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "from typing import Optional\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "model_id = \"stabilityai/stable-diffusion-2-1-base\"\n",
        "\n",
        "pipe = StableDiffusionPipeline.from_pretrained(model_id)\n",
        "pipe = pipe.to(\"cuda\")\n",
        "\n",
        "inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n",
        "scheduler = DDIMScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n",
        "\n",
        "prompt = \"A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "prompt_embeds = pipe.encode_prompt(prompt=prompt, negative_prompt=\"\", device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True)\n",
        "\n",
        "cond_prompt_embeds = prompt_embeds[0]\n",
        "uncond_prompt_embeds = prompt_embeds[1]\n",
        "\n",
        "prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])\n",
        "\n",
        "initial_latents = torch.randn((1, pipe.unet.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size), generator=torch.Generator().manual_seed(22)).to(\"cuda\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LumEzIEapW-c"
      },
      "source": [
        "## Load Training Data\n",
        "\n",
        "if you saved the training data from previous chapter you can uncomment this code and run this cell and skip the Training Section"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5fjwn0zPpW-d"
      },
      "outputs": [],
      "source": [
        "# loaded_data = torch.load(\"reconstruction_data.pt\")\n",
        "\n",
        "# real_image_initial_latents = loaded_data['initial_latent']\n",
        "# QT = loaded_data['QT']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4cXX-QV3pW-d"
      },
      "source": [
        "## Training Section\n",
        "#### Skip to Cross Attention Replacement if you loaded the data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9SLcEMo5pW-d"
      },
      "outputs": [],
      "source": [
        "import torchvision\n",
        "from PIL import Image\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import gc\n",
        "import requests\n",
        "\n",
        "url1 = 'https://raw.githubusercontent.com/OutofAi/StableFace/main/photo.png'\n",
        "filename1 = url1.split('/')[-1]\n",
        "response1 = requests.get(url1)\n",
        "with open(filename1, 'wb') as f:\n",
        "    f.write(response1.content)\n",
        "\n",
        "\n",
        "img = Image.open('photo.png')\n",
        "\n",
        "transform = torchvision.transforms.Compose([\n",
        "    torchvision.transforms.Resize((512, 512)),\n",
        "    torchvision.transforms.ToTensor()\n",
        "])\n",
        "\n",
        "loaded_image = transform(img).to(\"cuda\").unsqueeze(0)\n",
        "\n",
        "if loaded_image.shape[1] == 4:\n",
        "    loaded_image = loaded_image[:,:3,:,:]\n",
        "\n",
        "with torch.no_grad():\n",
        "    encoded_image = pipe.vae.encode(loaded_image*2 - 1)\n",
        "    real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()\n",
        "\n",
        "num_inference_steps = 10\n",
        "\n",
        "guidance_scale = 1\n",
        "inverse_scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = inverse_scheduler.timesteps\n",
        "\n",
        "latents = real_image_latents\n",
        "\n",
        "inversed_latents = []\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        inversed_latents.append(latents)\n",
        "\n",
        "        latent_model_input = torch.cat([latents] * 2)\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds_combined,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "        latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n",
        "\n",
        "\n",
        "# initial state\n",
        "real_image_initial_latents = latents\n",
        "\n",
        "\n",
        "W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)\n",
        "QT = nn.Parameter(W_values.clone())\n",
        "\n",
        "\n",
        "guidance_scale = 7.5\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "optimizer = torch.optim.AdamW([QT], lr=0.008)\n",
        "\n",
        "pipe.vae.eval()\n",
        "pipe.vae.requires_grad_(False)\n",
        "pipe.unet.eval()\n",
        "pipe.unet.requires_grad_(False)\n",
        "\n",
        "last_loss = 1\n",
        "\n",
        "for epoch in range(50):\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "    intermediate_values = real_image_initial_latents.clone().requires_grad_(True)\n",
        "\n",
        "    if last_loss < 0.02:\n",
        "        break\n",
        "    elif last_loss < 0.03:\n",
        "        for param_group in optimizer.param_groups:\n",
        "            param_group['lr'] = 0.003\n",
        "    elif last_loss < 0.035:\n",
        "        for param_group in optimizer.param_groups:\n",
        "            param_group['lr'] = 0.006\n",
        "\n",
        "    for i in range(num_inference_steps):\n",
        "        latents = intermediate_values.detach().clone().requires_grad_(True)\n",
        "\n",
        "        t = timesteps[i]\n",
        "\n",
        "        prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])\n",
        "\n",
        "        latent_model_input = torch.cat([latents] * 2)\n",
        "\n",
        "        noise_pred_model = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "        intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n",
        "\n",
        "        loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction=\"mean\")\n",
        "        last_loss = loss\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        print(f\"Loss (epoch {epoch} - Step {i}): {loss.item()}\")\n",
        "    print(f\"Reconstruction Loss (epoch {epoch}): {last_loss.item()}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YYrQt116pW-e"
      },
      "source": [
        "## Configure Attention Replacement Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UIYeftv7pW-f"
      },
      "outputs": [],
      "source": [
        "\n",
        "def contextual_forward(self, source_batch_index, target_batch_index, skip_dimension, should_replace = False):\n",
        "\n",
        "    def forward_modified(\n",
        "        hidden_states: torch.FloatTensor,\n",
        "        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n",
        "        attention_mask: Optional[torch.FloatTensor] = None,\n",
        "        temb: Optional[torch.FloatTensor] = None,\n",
        "    ) -> torch.FloatTensor:\n",
        "\n",
        "            residual = hidden_states\n",
        "\n",
        "            is_cross = not encoder_hidden_states is None\n",
        "\n",
        "            input_ndim = hidden_states.ndim\n",
        "\n",
        "            if input_ndim == 4:\n",
        "                batch_size, channel, height, width = hidden_states.shape\n",
        "                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n",
        "\n",
        "            batch_size, _, _ = (\n",
        "                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n",
        "            )\n",
        "\n",
        "            if self.group_norm is not None:\n",
        "                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n",
        "\n",
        "            query = self.to_q(hidden_states)\n",
        "\n",
        "            if encoder_hidden_states is None:\n",
        "                encoder_hidden_states = hidden_states\n",
        "            elif self.norm_cross:\n",
        "                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)\n",
        "\n",
        "            key = self.to_k(encoder_hidden_states)\n",
        "            value = self.to_v(encoder_hidden_states)\n",
        "\n",
        "            query = self.head_to_batch_dim(query)\n",
        "            key = self.head_to_batch_dim(key)\n",
        "            value = self.head_to_batch_dim(value)\n",
        "\n",
        "            attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2))\n",
        "\n",
        "            #############################################################\n",
        "            ### The replacing process of attention maps happens here    ###\n",
        "            #############################################################\n",
        "\n",
        "            # each attention_scores comes with a batch shape of heads * batch size and considering\n",
        "            # the unconditional prompt that we require for CFG technically if couple them up as one\n",
        "            # batch then technically it comes in the format of heads * batch size * 2 and we use that\n",
        "            # information to find the relevant indicies to each batch\n",
        "            num_heads = self.heads\n",
        "            source_starting_index = num_heads * source_batch_index * 2\n",
        "            source_ending_index = num_heads * (source_batch_index + 1) * 2\n",
        "            target_starting_index = num_heads * target_batch_index * 2\n",
        "            target_ending_index = num_heads * (target_batch_index + 1) * 2\n",
        "\n",
        "            dimension_squared = hidden_states.shape[1]\n",
        "\n",
        "            # our experiement showed that this is the combination granted the best results when it comes\n",
        "            # to facial related changes, hence this specific configuration for our replacement\n",
        "            if not is_cross and (should_replace or not dimension_squared == skip_dimension * skip_dimension):\n",
        "                attention_scores[target_starting_index:target_ending_index,:,:] = attention_scores[source_starting_index:source_ending_index,:,:]\n",
        "            #############################################################\n",
        "            attention_probs = attention_scores.softmax(dim=-1)\n",
        "            del attention_scores\n",
        "\n",
        "            hidden_states = torch.bmm(attention_probs, value)\n",
        "            hidden_states = self.batch_to_head_dim(hidden_states)\n",
        "            del attention_probs\n",
        "\n",
        "            hidden_states = self.to_out[0](hidden_states)\n",
        "\n",
        "            if input_ndim == 4:\n",
        "                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n",
        "\n",
        "            if self.residual_connection:\n",
        "                hidden_states = hidden_states + residual\n",
        "\n",
        "            hidden_states = hidden_states / self.rescale_output_factor\n",
        "\n",
        "            return hidden_states\n",
        "\n",
        "    return forward_modified\n",
        "\n",
        "def apply_forward_function(unet, child = None, source_batch_index=-1, target_batch_index=-1, should_replace= False):\n",
        "    if child == None:\n",
        "        children = unet.named_children()\n",
        "        for child in children:\n",
        "            apply_forward_function(unet, child[1], source_batch_index, target_batch_index, should_replace)\n",
        "    else:\n",
        "        if child.__class__.__name__ == 'Attention':\n",
        "            child.forward = contextual_forward(child, source_batch_index, target_batch_index, pipe.unet.config.sample_size, should_replace)\n",
        "        elif hasattr(child, 'children'):\n",
        "            for sub_child in child.children():\n",
        "                block_name = child.__class__.__name__\n",
        "\n",
        "                if \"Down\" in block_name:\n",
        "                    should_replace = True\n",
        "                elif \"Up\" in block_name:\n",
        "                    should_replace = False\n",
        "                elif \"Mid\" in block_name:\n",
        "                    should_replace = True\n",
        "                apply_forward_function(unet, sub_child, source_batch_index, target_batch_index, should_replace)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BhZvi7eCpW-g"
      },
      "source": [
        "## Create a Display Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eBzNJoAmpW-h"
      },
      "outputs": [],
      "source": [
        "def display_latents(latents):\n",
        "    with torch.no_grad():\n",
        "        image_0 = pipe.vae.decode(latents[0].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "        image_np_0 = image_0.squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()\n",
        "        image_np_0 = (image_np_0 - image_np_0.min()) / (image_np_0.max() - image_np_0.min())\n",
        "\n",
        "        image_1 = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "        image_np_1 = image_1.squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()\n",
        "        image_np_1 = (image_np_1 - image_np_1.min()) / (image_np_1.max() - image_np_1.min())\n",
        "\n",
        "        fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
        "\n",
        "        axes[0].imshow(image_np_0)\n",
        "        axes[0].axis('off')\n",
        "        axes[0].set_title('Latent 0')\n",
        "\n",
        "        axes[1].imshow(image_np_1)\n",
        "        axes[1].axis('off')\n",
        "        axes[1].set_title('Latent 1')\n",
        "\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A2-9MoJ9pW-h"
      },
      "source": [
        "# Apply Attention Replacement"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0oRQ-FqUpW-h"
      },
      "outputs": [],
      "source": [
        "new_prompt = \"A photo of a woman, curly hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=\"\")\n",
        "\n",
        "guidance_scale = 7.5\n",
        "num_inference_steps = 10\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "latents = torch.cat([real_image_initial_latents] * 2)\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    apply_forward_function(pipe.unet, source_batch_index=0, target_batch_index=1)\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds, QT[i].unsqueeze(0), new_prompt_embeds[0] ])\n",
        "        latent_model_input = torch.cat([latents[0].unsqueeze(0),latents[0].unsqueeze(0),latents[1].unsqueeze(0),latents[1].unsqueeze(0)])\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=modified_prompt_embeds,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = noise_pred.chunk(4)\n",
        "\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)\n",
        "\n",
        "        latents = scheduler.step(torch.cat([noise_pred, new_noise_pred]), t, latents, return_dict=False)[0]\n",
        "\n",
        "    display_latents(latents)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YRjWjT2ZpW-i"
      },
      "source": [
        "## Other prompts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lKuEcZw4pW-i"
      },
      "outputs": [],
      "source": [
        "new_prompt = \"A photo of a male, straight hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=\"\")\n",
        "\n",
        "guidance_scale = 7.5\n",
        "num_inference_steps = 10\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "latents = torch.cat([real_image_initial_latents] * 2)\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    apply_forward_function(pipe.unet, source_batch_index=0, target_batch_index=1)\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds, QT[i].unsqueeze(0), new_prompt_embeds[0] ])\n",
        "        latent_model_input = torch.cat([latents[0].unsqueeze(0),latents[0].unsqueeze(0),latents[1].unsqueeze(0),latents[1].unsqueeze(0)])\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=modified_prompt_embeds,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = noise_pred.chunk(4)\n",
        "\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)\n",
        "\n",
        "        latents = scheduler.step(torch.cat([noise_pred, new_noise_pred]), t, latents, return_dict=False)[0]\n",
        "\n",
        "    display_latents(latents)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BNqtoL_3pW-i"
      },
      "outputs": [],
      "source": [
        "new_prompt = \"A photo of a woman, wavy hair, light blonde and pink hair, smiling expression, grey background, closed eyes\"\n",
        "\n",
        "\n",
        "new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=\"\")\n",
        "\n",
        "guidance_scale = 7.5\n",
        "num_inference_steps = 10\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "latents = torch.cat([real_image_initial_latents] * 2)\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    apply_forward_function(pipe.unet, source_batch_index=0, target_batch_index=1)\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds, QT[i].unsqueeze(0), new_prompt_embeds[0] ])\n",
        "        latent_model_input = torch.cat([latents[0].unsqueeze(0),latents[0].unsqueeze(0),latents[1].unsqueeze(0),latents[1].unsqueeze(0)])\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=modified_prompt_embeds,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = noise_pred.chunk(4)\n",
        "\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)\n",
        "\n",
        "        latents = scheduler.step(torch.cat([noise_pred, new_noise_pred]), t, latents, return_dict=False)[0]\n",
        "\n",
        "    display_latents(latents)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K0X_vxvMpW-j"
      },
      "outputs": [],
      "source": [
        "new_prompt = \"A photo of a woman, wavy hair, dark blonde and pink hair, smiling expression, grey background, bangs\"\n",
        "\n",
        "\n",
        "new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=\"\")\n",
        "\n",
        "guidance_scale = 7.5\n",
        "num_inference_steps = 10\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "latents = torch.cat([real_image_initial_latents] * 2)\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    apply_forward_function(pipe.unet, source_batch_index=0, target_batch_index=1)\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds, QT[i].unsqueeze(0), new_prompt_embeds[0] ])\n",
        "        latent_model_input = torch.cat([latents[0].unsqueeze(0),latents[0].unsqueeze(0),latents[1].unsqueeze(0),latents[1].unsqueeze(0)])\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=modified_prompt_embeds,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = noise_pred.chunk(4)\n",
        "\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)\n",
        "\n",
        "        latents = scheduler.step(torch.cat([noise_pred, new_noise_pred]), t, latents, return_dict=False)[0]\n",
        "\n",
        "    display_latents(latents)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_cHYIxhRpW-j"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "L4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
